{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0aed74fd",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License.\n",
    "\n",
    "# UNet input size constraints\n",
    "\n",
    "MONAI provides an enhanced version of UNet (``monai.networks.nets.UNet``), which supports residual units and exposes more hyperparameters (such as ``strides``, ``kernel_size`` and ``up_kernel_size``) than ``monai.networks.nets.BasicUNet``. However, ``UNet`` places constraints on both its hyperparameters and the size of its input.\n",
    "\n",
    "The hyperparameter constraints are described in the network's docstring; this tutorial focuses on how to determine a valid input size.\n",
    "\n",
    "The last section, **Constraints of UNet**, summarizes the conclusions."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e9151e24",
   "metadata": {},
   "source": [
    "## Setup environments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "efcd04b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q monai-weekly"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a9c54384",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "86ee1f12",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0+untagged.2891.gccd32ca\n",
      "Numpy version: 1.25.1\n",
      "Pytorch version: 2.0.1\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
      "MONAI rev id: ccd32ca5e9e84562d2f388b45b6724b5c77c1f57\n",
      "MONAI __file__: /Users/<username>/Envs/monai/lib/python3.9/site-packages/monai/__init__.py\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.11\n",
      "ITK version: 5.3.0\n",
      "Nibabel version: 5.1.0\n",
      "scikit-image version: 0.21.0\n",
      "scipy version: 1.11.1\n",
      "Pillow version: 10.0.0\n",
      "Tensorboard version: 2.13.0\n",
      "gdown version: 4.7.1\n",
      "TorchVision version: 0.15.2\n",
      "tqdm version: 4.65.0\n",
      "lmdb version: 1.4.1\n",
      "psutil version: 5.9.5\n",
      "pandas version: 2.0.3\n",
      "einops version: 0.6.1\n",
      "transformers version: 4.21.3\n",
      "mlflow version: 2.4.2\n",
      "pynrrd version: 1.0.0\n",
      "clearml version: 1.11.2rc0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from monai.networks.nets import UNet\n",
    "import monai\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "monai.config.print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f64140c",
   "metadata": {},
   "source": [
    "## Check UNet structure"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30f9f2f7",
   "metadata": {},
   "source": [
    "The following figure comes from [Left-Ventricle Quantification Using Residual U-Net](https://link.springer.com/chapter/10.1007/978-3-030-12029-0_40).\n",
    "\n",
    "![image](../figures/UNet_structure.png)\n",
    "\n",
    "First of all, let's build a UNet instance to check its structure. `num_res_units` is set to `0` since it has no impact on the input size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fd05bcb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "UNet(\n",
       "  (model): Sequential(\n",
       "    (0): Convolution(\n",
       "      (conv): Conv3d(3, 8, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))\n",
       "      (adn): ADN(\n",
       "        (N): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n",
       "        (D): Dropout(p=0.0, inplace=False)\n",
       "        (A): PReLU(num_parameters=1)\n",
       "      )\n",
       "    )\n",
       "    (1): SkipConnection(\n",
       "      (submodule): Sequential(\n",
       "        (0): Convolution(\n",
       "          (conv): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(1, 1, 1))\n",
       "          (adn): ADN(\n",
       "            (N): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n",
       "            (D): Dropout(p=0.0, inplace=False)\n",
       "            (A): PReLU(num_parameters=1)\n",
       "          )\n",
       "        )\n",
       "        (1): SkipConnection(\n",
       "          (submodule): Convolution(\n",
       "            (conv): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
       "            (adn): ADN(\n",
       "              (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n",
       "              (D): Dropout(p=0.0, inplace=False)\n",
       "              (A): PReLU(num_parameters=1)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (2): Convolution(\n",
       "          (conv): ConvTranspose3d(48, 8, kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(1, 1, 1), output_padding=(2, 2, 2))\n",
       "          (adn): ADN(\n",
       "            (N): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\n",
       "            (D): Dropout(p=0.0, inplace=False)\n",
       "            (A): PReLU(num_parameters=1)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (2): Convolution(\n",
       "      (conv): ConvTranspose3d(16, 3, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network_0 = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=3,\n",
    "    out_channels=3,\n",
    "    channels=(8, 16, 32),\n",
    "    strides=(2, 3),\n",
    "    kernel_size=3,\n",
    "    up_kernel_size=3,\n",
    "    num_res_units=0,\n",
    ")\n",
    "print(len(network_0.model))\n",
    "\n",
    "network_0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9437ea49",
   "metadata": {},
   "source": [
    "As the printed structure shows, the network consists of three parts:\n",
    "\n",
    "1. The first down layer.\n",
    "2. The intermediate skip-connection-based block.\n",
    "3. The final up layer.\n",
    "\n",
    "If we want to build a deeper UNet, only the intermediate block will be expanded.\n",
    "\n",
    "Throughout the network, only two kinds of modules are used:\n",
    "1. `monai.networks.blocks.convolutions.Convolution`\n",
    "2. `monai.networks.layers.simplelayers.SkipConnection`\n",
    "\n",
    "These modules are built from the following four kinds of layers:\n",
    "1. Activation layers (`PReLU`).\n",
    "2. Dropout layers (`Dropout`).\n",
    "3. Normalization layers (`InstanceNorm3d`).\n",
    "4. Convolution layers (`Conv` and `ConvTranspose`).\n",
    "\n",
    "Among the layers, convolution layers may change the size of the input, and normalization layers may impose extra constraints on the input size.\n",
    "Among the modules, `SkipConnection` also imposes some constraints.\n",
    "\n",
    "Consequently, this tutorial examines the constraints of convolution layers, normalization layers and the `SkipConnection` module in turn."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bded0633",
   "metadata": {},
   "source": [
    "## Constraints of convolution layers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1f19415",
   "metadata": {},
   "source": [
    "### Conv layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "072ed303",
   "metadata": {},
   "source": [
    "PyTorch's official docs give the formulas for calculating the output size of [Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d) and [ConvTranspose3d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html#torch.nn.ConvTranspose3d) (the formulas for `1d` and `2d` are similar)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cef3f58",
   "metadata": {},
   "source": [
    "As the docs show, the output size depends on the input size and:\n",
    "- `stride`\n",
    "- `kernel_size`\n",
    "- `dilation`\n",
    "- `padding`\n",
    "\n",
    "In `monai.networks.nets.UNet`, users can only set `strides` and `kernel_size`; the other two parameters are decided by [monai.networks.blocks.convolutions.Convolution](https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/convolutions.py) (see the linked source for details).\n",
    "\n",
    "Therefore, here `dilation = 1` and `padding = (kernel_size - 1) / 2` (`kernel_size` is required to be odd, so `padding` is an integer).\n",
    "\n",
    "The output size of `Conv` can be calculated via the following simplified formula:\n",
    "`math.floor((input_size + stride - 1) / stride)`. The corresponding Python function is shown below. We only need to ensure **`math.floor((input_size + stride - 1) / stride) >= 1`**, which means **`input_size >= 1`**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "37f7e0e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_conv_output_size(input_tensor, stride):\n",
    "    # print the expected spatial output size of a Conv layer with the given stride\n",
    "    output_size = []\n",
    "    input_size = list(input_tensor.shape)[2:]  # drop the batch and channel dims\n",
    "    for size in input_size:\n",
    "        out = math.floor((size + stride - 1) / stride)\n",
    "        output_size.append(out)\n",
    "    print(output_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba1c88b5",
   "metadata": {},
   "source": [
    "Let's check if the function is correct:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0d5b0d70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 5, 10]\n"
     ]
    }
   ],
   "source": [
    "stride_value = 3\n",
    "example = torch.rand([1, 3, 1, 15, 29])\n",
    "get_conv_output_size(example, stride_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3b1b4388",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 5, 10])\n"
     ]
    }
   ],
   "source": [
    "output = nn.Conv3d(in_channels=3, out_channels=1, stride=stride_value, kernel_size=3, padding=1)(example)\n",
    "print(output.shape[2:])"
   ]
  },
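  {
   "cell_type": "markdown",
   "id": "c7a1e5d2",
   "metadata": {},
   "source": [
    "As a supplementary sketch, the simplified formula used in `get_conv_output_size` can be derived from the output-size formula in the PyTorch `Conv3d` docs. The symbols $L_{in}$, $L_{out}$, $k$, $s$, $p$ and $d$ are shorthand introduced here only for illustration; they stand for the input size, output size, `kernel_size`, `stride`, `padding` and `dilation` along one spatial dimension. Substituting $d = 1$ and $p = (k - 1) / 2$ gives:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "L_{out} &= \\left\\lfloor \\frac{L_{in} + 2p - d\\,(k - 1) - 1}{s} + 1 \\right\\rfloor \\\\\n",
    "&= \\left\\lfloor \\frac{L_{in} + (k - 1) - (k - 1) - 1}{s} + 1 \\right\\rfloor \\\\\n",
    "&= \\left\\lfloor \\frac{L_{in} + s - 1}{s} \\right\\rfloor\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "For instance, with $s = 3$ the sizes $1$, $15$ and $29$ map to $\\lfloor 3 / 3 \\rfloor = 1$, $\\lfloor 17 / 3 \\rfloor = 5$ and $\\lfloor 31 / 3 \\rfloor = 10$, which agrees with the check above."
   ]
  },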
  {
   "cell_type": "markdown",
   "id": "a47f741e",
   "metadata": {},
   "source": [
    "### ConvTranspose layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b458f329",
   "metadata": {},
   "source": [
    "Similarly, due to the default settings in [monai.networks.blocks.convolutions.Convolution](https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/convolutions.py), `output_padding = stride - 1`. The output size of `ConvTranspose` can be simplified as:\n",
    "`input_size * stride`.\n",
    "Therefore, before entering the `ConvTranspose` layer, we only need to ensure **`input_size >= 1`**.\n",
    "Let's check if the formula is correct:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "caece123",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, 45, 87]\n"
     ]
    }
   ],
   "source": [
    "stride_value = 3\n",
    "print([i * stride_value for i in example.shape[2:]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f67804d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 45, 87])\n"
     ]
    }
   ],
   "source": [
    "output = nn.ConvTranspose3d(\n",
    "    in_channels=3,\n",
    "    out_channels=1,\n",
    "    stride=stride_value,\n",
    "    kernel_size=3,\n",
    "    padding=1,\n",
    "    output_padding=stride_value - 1,\n",
    ")(example)\n",
    "print(output.shape[2:])"
   ]
  },
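  {
   "cell_type": "markdown",
   "id": "5f3b9a10",
   "metadata": {},
   "source": [
    "As a supplementary sketch, the simplification `input_size * stride` used above can be derived from the output-size formula in the PyTorch `ConvTranspose3d` docs. The symbols $L_{in}$, $L_{out}$, $k$, $s$, $p$, $d$ and $p_{out}$ are shorthand introduced here only for illustration; they stand for the input size, output size, `kernel_size`, `stride`, `padding`, `dilation` and `output_padding` along one spatial dimension. Substituting $d = 1$, $p = (k - 1) / 2$ and $p_{out} = s - 1$ gives:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "L_{out} &= (L_{in} - 1)\\,s - 2p + d\\,(k - 1) + p_{out} + 1 \\\\\n",
    "&= (L_{in} - 1)\\,s - (k - 1) + (k - 1) + (s - 1) + 1 \\\\\n",
    "&= L_{in}\\,s\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "With $s = 3$, the sizes $1$, $15$ and $29$ therefore become $3$, $45$ and $87$, which agrees with the check above."
   ]
  },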
  {
   "cell_type": "markdown",
   "id": "391b93e6",
   "metadata": {},
   "source": [
    "## Constraints of normalization layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dc4be9d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('INSTANCE', 'BATCH', 'INSTANCE_NVFUSER', 'GROUP', 'LAYER', 'LOCALRESPONSE', 'SYNCBATCH')\n"
     ]
    }
   ],
   "source": [
    "print(monai.networks.layers.factories.Norm.names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e47a8ef",
   "metadata": {},
   "source": [
    "MONAI's norm factory provides six kinds of normalization layers (the seventh name printed above, `INSTANCE_NVFUSER`, is an alternative implementation of instance normalization). The official docs can be found [here](https://pytorch.org/docs/stable/nn.html#normalization-layers), and their constraints are shown in [torch.nn.functional](https://pytorch.org/docs/stable/_modules/torch/nn/functional.html).\n",
    "\n",
    "However, the following normalization layers will not be discussed:\n",
    "1. SyncBatchNorm, since it only supports `DistributedDataParallel`; please check the official docs for more details.\n",
    "2. LayerNorm, since its parameter `normalized_shape` should equal `[num_channels, *spatial_dims]`, and no single fixed value fits all normalization layers in the network.\n",
    "3. GroupNorm, since its parameter `num_channels` should equal the number of channels of the input, and no single fixed value fits all normalization layers in the network.\n",
    "\n",
    "Therefore, let's check the other three normalization layers: batch normalization, instance normalization and local response normalization."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b611a564",
   "metadata": {},
   "source": [
    "### Batch normalization\n",
    "\n",
    "The input must pass `torch.nn.functional._verify_batch_size`, which (in training mode) requires the product of all dimensions except the channel dimension to be larger than 1. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "732f2769",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = nn.BatchNorm3d(num_features=3)\n",
    "for size in [[1, 3, 2, 1, 1], [2, 3, 1, 1, 1]]:\n",
    "    output = batch(torch.randn(size))\n",
    "\n",
    "# uncomment the following line to see a ValueError\n",
    "# batch(torch.randn([1, 3, 1, 1, 1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07347476",
   "metadata": {},
   "source": [
    "In practice, batch normalization is rarely useful when the batch size is 1. Therefore, the constraint can be simplified to **the batch size should be larger than 1**."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73c2b29f",
   "metadata": {},
   "source": [
    "### Instance normalization\n",
    "\n",
    "The input must pass `torch.nn.functional._verify_spatial_size`, which requires the product of all spatial dimensions to be larger than 1. Therefore, **at least one spatial dimension should have a size larger than 1**. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0a33cc15",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance = nn.InstanceNorm3d(num_features=3)\n",
    "for size in [[1, 3, 2, 1, 1], [1, 3, 1, 2, 1]]:\n",
    "    output = instance(torch.randn(size))\n",
    "\n",
    "# uncomment the following line to see a ValueError\n",
    "# instance(torch.randn([2, 3, 1, 1, 1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78f05589",
   "metadata": {},
   "source": [
    "### Local response normalization\n",
    "\n",
    "**No constraint**. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "23d75904",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[[-0.6150]]]]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nn.LocalResponseNorm(size=1)(torch.randn(1, 1, 1, 1, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd830ec6",
   "metadata": {},
   "source": [
    "## Constraints of SkipConnection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37f2aff4",
   "metadata": {},
   "source": [
    "In this section, we check whether the [SkipConnection](https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/layers/simplelayers.py) module itself imposes further constraints on the input size.\n",
    "\n",
    "In `UNet`, the `SkipConnection` is used via:\n",
    "\n",
    "`nn.Sequential(down, SkipConnection(subblock), up)`\n",
    "\n",
    "and the following line is executed in its forward function:\n",
    "\n",
    "`torch.cat([x, self.submodule(x)], dim=1)`.\n",
    "\n",
    "This requires that, for any input tensor, the output of `self.submodule` has the same spatial size as its input."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba8b9380",
   "metadata": {},
   "source": [
    "### When `len(channels) = 2` "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdd7033e",
   "metadata": {},
   "source": [
    "If `len(channels) = 2`, there is only one `SkipConnection` module in the network, and it wraps a single down layer with `stride = 1`. From the formulas derived in the previous section, we know that this layer does not change the size, so we only need to meet the constraints of the normalization layer inside it:\n",
    "\n",
    "1. When using batch normalization, the batch size should be larger than 1.\n",
    "\n",
    "2. When using instance normalization, the size of at least one spatial dimension should be larger than 1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e391f534",
   "metadata": {},
   "source": [
    "### When `len(channels) > 2` "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2efce3e2",
   "metadata": {},
   "source": [
    "If `len(channels) > 2`, more `SkipConnection` modules are built, and each of them wraps one down layer and one up layer. Consequently, **the output of the up layer must have the same spatial size as the input to the down layer**. The stride values for these modules come from `strides[1:]`; hence for each stride value `s` in `strides[1:]` and each spatial size `v` of the input, the constraint of the corresponding `SkipConnection` module is:\n",
    "\n",
    "```\n",
    "math.floor((v + s - 1) / s) = v / s\n",
    "```\n",
    "\n",
    "Since the left-hand side of the equation is a positive integer, `v` must be divisible by `s`. If we assume `v = k * s` where `k >= 1`, we get:\n",
    "```\n",
    "math.floor(k + (s - 1) / s) = k\n",
    "k + math.floor((s - 1) / s) = k\n",
    "math.floor((s - 1) / s) = 0\n",
    "```\n",
    "Since `0 <= (s - 1) / s < 1`, the last equation always holds. Thus, for a single `SkipConnection` module, the constraint is exactly that every spatial size of its input is divisible by `s`."
   ]
  },
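  {
   "cell_type": "markdown",
   "id": "9d2e4c71",
   "metadata": {},
   "source": [
    "Another way to see the divisibility requirement is to follow one spatial size through the down layer and back through the up layer. This is a supplementary sketch that reuses `v` and `s` from above and combines the two simplified formulas from the convolution section (down: $\\lfloor (v + s - 1) / s \\rfloor$, up: multiply by $s$):\n",
    "\n",
    "$$\n",
    "s \\left\\lfloor \\frac{v + s - 1}{s} \\right\\rfloor = s \\left\\lceil \\frac{v}{s} \\right\\rceil \\ge v, \\quad \\text{with equality if and only if } s \\text{ divides } v.\n",
    "$$\n",
    "\n",
    "For example, with $s = 3$ and $v = 14$ the down layer gives $\\lfloor 16 / 3 \\rfloor = 5$ and the up layer returns $15 \\ne 14$, so the concatenation fails. With $v = 15$ the down layer gives $\\lfloor 17 / 3 \\rfloor = 5$ and the up layer returns $15$, so the sizes match."
   ]
  },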
  {
   "cell_type": "markdown",
   "id": "ae3eb93c",
   "metadata": {},
   "source": [
    "For the nested `SkipConnection` modules as a whole, assume `[H, W, D]` is the input spatial size; then for `v in [H, W, D]`:\n",
    "\n",
    "**`np.remainder(v, np.prod(strides[1:])) == 0`**\n",
    "\n",
    "In addition, there may be more constraints from normalization layers:\n",
    "\n",
    "1. When using batch normalization, the batch size of the input should be larger than 1.\n",
    "\n",
    "2. When using instance normalization, the size of at least one spatial dimension should be larger than 1 at every layer, including the innermost one, whose spatial size is `v / np.prod(strides[1:])`. Strictly, this only needs `d >= 2 * np.prod(strides[1:])` for `d = max(H, W, D)`; this tutorial uses the simple sufficient condition: **assume `d = max(H, W, D)`, `d` should meet: `np.remainder(d, 2 * np.prod(strides[1:])) == 0`**."
   ]
  },
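  {
   "cell_type": "markdown",
   "id": "e18a6b3f",
   "metadata": {},
   "source": [
    "To make the product condition concrete, here is a small worked sketch. The symbols $s_1, \\dots, s_n$ (the entries of `strides[1:]`), $P$ (their product) and $v_i$ (the spatial size after the $i$-th nested down layer) are shorthand introduced only for this illustration. Applying the single-module divisibility result level by level gives:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "v_0 &= v, \\qquad v_i = \\frac{v_{i-1}}{s_i} \\quad (i = 1, \\dots, n), \\\\\n",
    "v_n &= \\frac{v}{P}, \\qquad P = \\prod_{i=1}^{n} s_i,\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "and every $v_i$ is an integer exactly when $P$ divides $v$.\n",
    "\n",
    "For example, with `strides[1:] = (2, 4)` (so $P = 8$), $v = 8$ gives sizes $8 \\to 4 \\to 1$ on the way down and $1 \\to 4 \\to 8$ on the way up, so both concatenations match. With $v = 12$, the sizes go $12 \\to 6 \\to \\lfloor (6 + 3) / 4 \\rfloor = 2$, and the inner up layer returns $2 \\times 4 = 8 \\ne 6$, so the inner concatenation fails. For instance normalization, the innermost size of the largest dimension, $d / P$, additionally has to be at least $2$."
   ]
  },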
  {
   "cell_type": "markdown",
   "id": "8e2d99ef",
   "metadata": {},
   "source": [
    "## Constraints of UNet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "554744bc",
   "metadata": {},
   "source": [
    "As discussed in the first section, UNet consists of 1) a down layer, 2) one or more skip connection module(s) and 3) an up layer. Based on the analysis of each layer and module, the constraints of the network can be summarized as follows."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7ae8cd7",
   "metadata": {},
   "source": [
    "### When `len(channels) = 2`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cd1d3b5",
   "metadata": {},
   "source": [
    "If `len(channels) == 2`, `strides` must contain a single value. Assume `s` is that value and the input size is `[B, C, H, W, D]`. The constraints are:\n",
    "\n",
    "1. If using batch normalization: **`B > 1`.**\n",
    "2. If using local response normalization: no constraint.\n",
    "3. If using instance normalization, assume `d = max(H, W, D)`, then `math.floor((d + s - 1) / s) >= 2`, which means **`d >= s + 1`.**\n",
    "\n",
    "The following are the corresponding examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1bdc5c8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 3, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "# example 1: len(channels) = 2, batch norm, batch size > 1.\n",
    "network = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=3,\n",
    "    channels=(8, 16),\n",
    "    strides=(3,),\n",
    "    kernel_size=3,\n",
    "    up_kernel_size=3,\n",
    "    num_res_units=0,\n",
    "    norm=\"batch\",\n",
    ")\n",
    "example = torch.rand([2, 1, 1, 1, 1])\n",
    "print(network(example).shape)\n",
    "\n",
    "# uncomment the following two lines to see the error\n",
    "# example = torch.rand([1, 1, 1, 1, 1])\n",
    "# print(network(example).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7485a83a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 3, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "# example 2: len(channels) = 2, localresponse.\n",
    "network = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=3,\n",
    "    channels=(8, 16),\n",
    "    strides=(3,),\n",
    "    kernel_size=1,\n",
    "    up_kernel_size=1,\n",
    "    num_res_units=1,\n",
    "    norm=(\"localresponse\", {\"size\": 1}),\n",
    ")\n",
    "example = torch.rand([1, 1, 1, 1, 1])\n",
    "print(network(example).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6a31861f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 6, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "# example 3: len(channels) = 2, instance norm.\n",
    "network = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=3,\n",
    "    channels=(8, 16),\n",
    "    strides=(3,),\n",
    "    kernel_size=3,\n",
    "    up_kernel_size=5,\n",
    "    num_res_units=2,\n",
    "    norm=\"instance\",\n",
    ")\n",
    "example = torch.rand([1, 1, 4, 1, 1])\n",
    "print(network(example).shape)\n",
    "\n",
    "# uncomment the following two lines to see the error\n",
    "# example = torch.rand([1, 1, 1, 1, 3])\n",
    "# print(network(example).shape)"
   ]
  },
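  {
   "cell_type": "markdown",
   "id": "4b7c0d95",
   "metadata": {},
   "source": [
    "As a short supplementary derivation, the instance normalization bound `d >= s + 1` stated for `len(channels) = 2` follows directly from the simplified `Conv` formula. Because $2$ is an integer, $\\lfloor x \\rfloor \\ge 2$ holds exactly when $x \\ge 2$, so:\n",
    "\n",
    "$$\n",
    "\\left\\lfloor \\frac{d + s - 1}{s} \\right\\rfloor \\ge 2 \\iff \\frac{d + s - 1}{s} \\ge 2 \\iff d + s - 1 \\ge 2s \\iff d \\ge s + 1\n",
    "$$\n",
    "\n",
    "With $s = 3$, $d = 4$ gives $\\lfloor 6 / 3 \\rfloor = 2$, which satisfies the constraint, while $d = 3$ gives $\\lfloor 5 / 3 \\rfloor = 1$, which violates it. This matches the passing and failing inputs in example 3 above."
   ]
  },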
  {
   "cell_type": "markdown",
   "id": "ba6057e8",
   "metadata": {},
   "source": [
    "### When `len(channels) > 2`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c804fa49",
   "metadata": {},
   "source": [
    "Assume the input size is `[B, C, H, W, D]`, and `s = strides`. The common constraints are:\n",
    "\n",
    "```\n",
    "For v in [H, W, D]:\n",
    "     size = math.floor((v + s[0] - 1) / s[0])\n",
    "     size should meet: np.remainder(size, np.prod(s[1:])) == 0\n",
    "```\n",
    "In addition,\n",
    "1. If using batch normalization: **`B > 1`.**\n",
    "2. If using local response normalization: no additional constraint.\n",
    "3. If using instance normalization, it is sufficient that:\n",
    "```\n",
    "d = max(H, W, D)\n",
    "max_size = math.floor((d + s[0] - 1) / s[0])\n",
    "max_size should meet: np.remainder(max_size, 2 * np.prod(s[1:])) == 0\n",
    "```\n",
    "\n",
    "The following are the corresponding examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d234f140",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 15, 15, 15])\n"
     ]
    }
   ],
   "source": [
    "# example 1: strides=(3, 5), batch norm, batch size > 1.\n",
    "# thus math.floor((v + 2) / 3) should be 5 * k. If k = 1, v should be in [13, 14, 15].\n",
    "network = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=3,\n",
    "    channels=(8, 16, 32),\n",
    "    strides=(3, 5),\n",
    "    kernel_size=3,\n",
    "    up_kernel_size=3,\n",
    "    num_res_units=0,\n",
    "    norm=\"batch\",\n",
    ")\n",
    "example = torch.rand([2, 1, 13, 14, 15])\n",
    "print(network(example).shape)\n",
    "\n",
    "# uncomment the following two lines to see the error\n",
    "# example = torch.rand([1, 1, 12, 14, 15])\n",
    "# print(network(example).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ccf53aa1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 24, 24, 24])\n"
     ]
    }
   ],
   "source": [
    "# example 2: strides=(3, 2, 4), localresponse.\n",
    "# thus math.floor((v + 2) / 3) should be 8 * k. If k = 1, v should be in [22, 23, 24].\n",
    "network = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=3,\n",
    "    channels=(8, 16, 32, 16),\n",
    "    strides=(3, 2, 4),\n",
    "    kernel_size=1,\n",
    "    up_kernel_size=3,\n",
    "    num_res_units=10,\n",
    "    norm=(\"localresponse\", {\"size\": 1}),\n",
    ")\n",
    "example = torch.rand([1, 1, 22, 23, 24])\n",
    "print(network(example).shape)\n",
    "\n",
    "# uncomment the following two lines to see the error\n",
    "# example = torch.rand([1, 1, 25, 23, 24])\n",
    "# print(network(example).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "da6e9277",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 24, 12, 12])\n"
     ]
    }
   ],
   "source": [
    "# example 3: strides=(1, 2, 2, 3), instance norm.\n",
    "# thus v should be 12 * k. If k = 1, v should be 12. In addition, the maximum size should be 24 * k.\n",
    "\n",
    "network = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=3,\n",
    "    channels=(8, 16, 32, 32, 16),\n",
    "    strides=(1, 2, 2, 3),\n",
    "    kernel_size=5,\n",
    "    up_kernel_size=3,\n",
    "    num_res_units=5,\n",
    "    norm=\"instance\",\n",
    ")\n",
    "example = torch.rand([1, 1, 24, 12, 12])\n",
    "print(network(example).shape)\n",
    "\n",
    "# uncomment the following two lines to see the error\n",
    "# example = torch.rand([1, 1, 12, 12, 12])\n",
    "# print(network(example).shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
